{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f431da81-4d94-4fd1-8a51-952ef824cb46",
   "metadata": {},
   "source": [
    "# Cifar Library and Demo\n",
    "\n",
    "[This is a library](https://pkg.go.dev/github.com/gomlx/gomlx/examples/cifar) to download and parse the Cifar datasets (Cifar-10 and Cifar-100), plus a very small demo of an FNN (Feedforward Neural Network) with GoMLX. FNNs are notoriously bad for images, but this is only a demo. For a more serious image classification model, look for the Resnet50 model (old but still good; as of this writing, the best results are obtained with ViT models).\n",
    "\n",
    "The CIFAR-10 and CIFAR-100 datasets are labeled subsets of the 80 million tiny images dataset. They were collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. See more details on [its homepage](https://www.cs.toronto.edu/~kriz/cifar.html).\n",
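    "\n",
    "The binary versions of both datasets use a fixed-size record format, documented on the dataset homepage. Each CIFAR-10 record has 1 label byte followed by 3072 pixel bytes. Each CIFAR-100 record has 2 label bytes (coarse label, then fine label) followed by the same 3072 pixel bytes. The pixels are stored as three 32x32 planes (red, then green, then blue), each in row-major order. The sketch below shows how one such record could be parsed in plain Go. `parseRecord` is a hypothetical helper, not the `cifar` package API. Converting the pixels to channels-last `float32` values scaled to [0, 1] is an assumption made only for illustration.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import \"fmt\"\n",
    "\n",
    "// Record layout of the CIFAR binary files: 1 (CIFAR-10) or 2 (CIFAR-100) label bytes,\n",
    "// followed by 3 planes (red, green, blue) of 32x32 pixels, each plane in row-major order.\n",
    "const (\n",
    "    height, width, depth = 32, 32, 3\n",
    "    numPixelBytes        = height * width * depth // 3072\n",
    ")\n",
    "\n",
    "// parseRecord (hypothetical helper) splits one record into its label bytes and converts the\n",
    "// planar (channels-first) pixels into a channels-last [height][width][depth] layout,\n",
    "// flattened into a []float32 with values scaled to [0, 1].\n",
    "func parseRecord(record []byte, numLabelBytes int) (labels []byte, pixels []float32, err error) {\n",
    "    if len(record) != numLabelBytes+numPixelBytes {\n",
    "        return nil, nil, fmt.Errorf(\"record has %d bytes, expected %d\", len(record), numLabelBytes+numPixelBytes)\n",
    "    }\n",
    "    labels = record[:numLabelBytes]\n",
    "    planes := record[numLabelBytes:]\n",
    "    pixels = make([]float32, numPixelBytes)\n",
    "    for c := 0; c < depth; c++ {\n",
    "        for y := 0; y < height; y++ {\n",
    "            for x := 0; x < width; x++ {\n",
    "                src := c*height*width + y*width + x // Planar: channel, row, column.\n",
    "                dst := (y*width+x)*depth + c        // Channels-last: row, column, channel.\n",
    "                pixels[dst] = float32(planes[src]) / 255\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "    return labels, pixels, nil\n",
    "}\n",
    "\n",
    "func main() {\n",
    "    // A synthetic CIFAR-10 record: label 9 (truck), all pixels black except the green\n",
    "    // value of the top-left pixel.\n",
    "    record := make([]byte, 1+numPixelBytes)\n",
    "    record[0] = 9\n",
    "    record[1+height*width] = 255 // First byte of the green plane.\n",
    "    labels, pixels, err := parseRecord(record, 1)\n",
    "    if err != nil {\n",
    "        panic(err)\n",
    "    }\n",
    "    fmt.Println(labels[0], pixels[0], pixels[1], pixels[2]) // Prints: 9 0 1 0\n",
    "}\n",
    "```\n",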
    "\n",
    "This notebook serves as documentation and as an example for the [github.com/gomlx/gomlx/examples/cifar](https://github.com/gomlx/gomlx/tree/main/examples/cifar) library. The complete demo code can be found in [.../examples/cifar/demo/](https://github.com/gomlx/gomlx/tree/main/examples/cifar/demo)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd398a51-c10a-4d05-91f7-f978856a8d8b",
   "metadata": {},
   "source": [
    "## Environment Set Up\n",
    "\n",
    "Let's set up a Go workspace (`go.work`) to use the local copy of GoMLX, so the dataset code can be developed jointly with the model. Data pre-processing and model code are often developed together this way, through experimentation.\n",
    "\n",
    "If you are not changing the GoMLX code, feel free to skip this cell. If you use a different directory for your projects, change it below.\n",
    "\n",
    "Notice that `${HOME}/Projects/gomlx` is where the GoMLX code is copied by default in [its Docker image](https://hub.docker.com/repository/docker/janpfeifer/gomlx_jupyterlab/general)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "22abc207-e8b0-44ef-840d-fee3550cffca",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t- Added replace rule for module \"github.com/gomlx/gomlx\" to local directory \"/home/janpf/Projects/gomlx\".\n"
     ]
    }
   ],
   "source": [
    "!*rm -f go.work && go work init && go work use . \"${HOME}/Projects/gomlx\"\n",
    "%goworkfix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "443e9e56-ef12-4d04-b7bb-99478fa2b662",
   "metadata": {},
   "source": [
    "## Data Preparation\n",
    "\n",
    "### Downloading data files\n",
    "\n",
    "To download, uncompress and untar the datasets into the local directory, run the cell below. If the data has already been downloaded to the directory given by `--data`, the cell returns immediately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a61fdd1a-74a0-46f2-9cfc-f79bfa37e298",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import (\n",
    "    \"github.com/gomlx/gomlx/examples/cifar\"\n",
    "    \"github.com/gomlx/gomlx/pkg/support/fsutil\"\n",
    "    \"github.com/janpfeifer/must\"\n",
    ")\n",
    "\n",
    "var flagDataDir = flag.String(\"data\", \"~/work/cifar\", \"Directory to cache downloaded and generated dataset files.\")\n",
    "\n",
    "func AssertDownloaded() {\n",
    "    *flagDataDir = must.M1(fsutil.ReplaceTildeInDir(*flagDataDir))\n",
    "    if !fsutil.MustFileExists(*flagDataDir) {\n",
    "        must.M(os.MkdirAll(*flagDataDir, 0777))\n",
    "    }\n",
    "\n",
    "    must.M(cifar.DownloadCifar10(*flagDataDir))\n",
    "    must.M(cifar.DownloadCifar100(*flagDataDir))\n",
    "}\n",
    "\n",
    "%%\n",
    "AssertDownloaded()"
   ]
  },
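  {
   "cell_type": "markdown",
   "id": "2d7f1c3a-6b8e-4a15-9f02-7e4c1b9d3a58",
   "metadata": {},
   "source": [
    "`cifar.DownloadCifar10` and `cifar.DownloadCifar100` handle all of this for us. To show the mechanics, here is a minimal sketch of the *uncompress, untar, and skip if already present* part, using only the Go standard library. `extractIfMissing` is a hypothetical helper, and this does not claim the `cifar` package is implemented this way. The HTTP download itself is left out: `main` uses a tiny archive built in memory instead.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import (\n",
    "    \"archive/tar\"\n",
    "    \"bytes\"\n",
    "    \"compress/gzip\"\n",
    "    \"fmt\"\n",
    "    \"io\"\n",
    "    \"os\"\n",
    "    \"path/filepath\"\n",
    "    \"strings\"\n",
    ")\n",
    "\n",
    "// extractIfMissing (hypothetical helper) returns immediately if dataDir/markerDir already\n",
    "// exists. Otherwise it gunzips and untars the archive read from r into dataDir.\n",
    "func extractIfMissing(r io.Reader, dataDir, markerDir string) (extracted bool, err error) {\n",
    "    if _, err := os.Stat(filepath.Join(dataDir, markerDir)); err == nil {\n",
    "        return false, nil // Already downloaded and extracted.\n",
    "    }\n",
    "    gz, err := gzip.NewReader(r)\n",
    "    if err != nil {\n",
    "        return false, err\n",
    "    }\n",
    "    defer gz.Close()\n",
    "    root := filepath.Clean(dataDir) + string(os.PathSeparator)\n",
    "    tr := tar.NewReader(gz)\n",
    "    for {\n",
    "        hdr, err := tr.Next()\n",
    "        if err == io.EOF {\n",
    "            return true, nil\n",
    "        }\n",
    "        if err != nil {\n",
    "            return false, err\n",
    "        }\n",
    "        target := filepath.Join(dataDir, hdr.Name)\n",
    "        // Reject entries that would be written outside of dataDir.\n",
    "        if !strings.HasPrefix(target+string(os.PathSeparator), root) {\n",
    "            return false, fmt.Errorf(\"archive entry %q escapes %q\", hdr.Name, dataDir)\n",
    "        }\n",
    "        switch hdr.Typeflag {\n",
    "        case tar.TypeDir:\n",
    "            if err := os.MkdirAll(target, 0o755); err != nil {\n",
    "                return false, err\n",
    "            }\n",
    "        case tar.TypeReg:\n",
    "            if err := writeFile(target, tr); err != nil {\n",
    "                return false, err\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "// writeFile creates target (and its parent directories) with the contents read from r.\n",
    "func writeFile(target string, r io.Reader) error {\n",
    "    if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {\n",
    "        return err\n",
    "    }\n",
    "    f, err := os.Create(target)\n",
    "    if err != nil {\n",
    "        return err\n",
    "    }\n",
    "    if _, err := io.Copy(f, r); err != nil {\n",
    "        f.Close()\n",
    "        return err\n",
    "    }\n",
    "    return f.Close()\n",
    "}\n",
    "\n",
    "func check(err error) {\n",
    "    if err != nil {\n",
    "        panic(err)\n",
    "    }\n",
    "}\n",
    "\n",
    "func main() {\n",
    "    // Build a tiny .tar.gz in memory, standing in for the real download.\n",
    "    var buf bytes.Buffer\n",
    "    gw := gzip.NewWriter(&buf)\n",
    "    tw := tar.NewWriter(gw)\n",
    "    content := []byte(\"fake batch\")\n",
    "    check(tw.WriteHeader(&tar.Header{Name: \"cifar-10-batches-bin/data_batch_1.bin\", Mode: 0o644, Size: int64(len(content)), Typeflag: tar.TypeReg}))\n",
    "    _, err := tw.Write(content)\n",
    "    check(err)\n",
    "    check(tw.Close())\n",
    "    check(gw.Close())\n",
    "\n",
    "    dataDir, err := os.MkdirTemp(\"\", \"cifar\")\n",
    "    check(err)\n",
    "    defer os.RemoveAll(dataDir)\n",
    "    // The first call extracts; the second one finds the directory and returns immediately.\n",
    "    for i := 0; i < 2; i++ {\n",
    "        extracted, err := extractIfMissing(bytes.NewReader(buf.Bytes()), dataDir, \"cifar-10-batches-bin\")\n",
    "        fmt.Println(\"extracted:\", extracted, \"ok:\", err == nil)\n",
    "    }\n",
    "}\n",
    "```"
   ]
  },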
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "345076ef-e2a3-4405-b47d-801c9ec62b91",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 48K\n",
      "drwxr-x--- 2 janpf janpf 4.0K Oct 10 17:53 base_cnn_model\n",
      "drwxr-x--- 2 janpf janpf 4.0K Oct 10 17:52 base_fnn_model\n",
      "drwxr-x--- 2 janpf janpf 4.0K Nov 12  2024 base_kan_model\n",
      "drwxr-xr-x 2 janpf janpf 4.0K Jun  4  2009 cifar-10-batches-bin\n",
      "drwxr-xr-x 2 janpf janpf 4.0K Feb 20  2010 cifar-100-binary\n",
      "drwxr-x--- 2 janpf janpf 4.0K Aug  2  2024 cnn\n",
      "drwxr-x--- 2 janpf janpf 4.0K Aug  1  2024 cnn_layer\n",
      "drwxr-x--- 2 janpf janpf 4.0K Jul 31  2024 cnn_nonorm\n",
      "drwxr-x--- 2 janpf janpf 4.0K Jul 31  2024 fnn_batchnorm_0\n",
      "drwxr-x--- 2 janpf janpf 4.0K Aug  1  2024 fnn_layer\n",
      "drwxr-x--- 2 janpf janpf 4.0K Nov 12  2024 kan\n",
      "drwxr-x--- 2 janpf janpf 4.0K Nov 12  2024 r001\n"
     ]
    }
   ],
   "source": [
    "!ls -lh ~/work/cifar/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93d4ae39-915c-478c-887f-b1c8c080865b",
   "metadata": {},
   "source": [
    "### Sample some images\n",
    "\n",
    "`cifar.NewDataset` creates a `data.InMemoryDataset` that can be used for training, for evaluation, or just to sample a few examples, which is what we do below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0dabc63a-b918-4065-9b23-0c5a22446458",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p>Samples Cifar-10</p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table><tr>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">truck ([9])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">frog ([6])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">horse ([7])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">truck ([9])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">dog ([5])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">bird ([2])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">frog ([6])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">automobile ([1])</figcaption></figure></td>\n",
       "</tr></table>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p>Samples Cifar-100</p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table><tr>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">spider ([79])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">lobster ([45])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">hamster ([36])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">trout ([91])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">ray ([67])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">bed ([5])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">couch ([25])</figcaption></figure></td>\n",
       "<td><figure style=\"padding:4px;text-align: center;\"><img width=\"64\" height=\"64\" src=\"\"><figcaption style=\"text-align: center;\">bear ([3])</figcaption></figure></td>\n",
       "</tr></table>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import (\n",
    "    \"strings\"\n",
    "    \"github.com/gomlx/gopjrt/dtypes\"\n",
    "    \"github.com/gomlx/gomlx/backends\"\n",
    "    \"github.com/gomlx/gomlx/examples/cifar\"\n",
    "    \"github.com/gomlx/gomlx/pkg/core/shapes\"\n",
    "    \"github.com/gomlx/gomlx/pkg/core/tensors/images\"\n",
    "    \"github.com/janpfeifer/gonb/gonbui\"\n",
    "\n",
    "    _ \"github.com/gomlx/gomlx/backends/default\"\n",
    ")\n",
    "\n",
    "var (\n",
    "    // Model DType, used everywhere.\n",
    "    DType = dtypes.Float32\n",
    ")\n",
    "\n",
    "// sampleToNotebook generates a sample of Cifar-10 and Cifar-100 in a GoNB Jupyter Notebook.\n",
    "func sampleToNotebook() {\n",
    "    // Load data into tensors.\n",
    "    backend := backends.MustNew()\n",
    "    ds10 := cifar.NewDataset(backend, \"Samples Cifar-10\", *flagDataDir, cifar.C10, DType, cifar.Train).Shuffle()\n",
    "    ds100 := cifar.NewDataset(backend, \"Samples Cifar-100\", *flagDataDir, cifar.C100, DType, cifar.Train).Shuffle()\n",
    "    sampleImages(ds10, 8, cifar.C10Labels)\n",
    "    sampleImages(ds100, 8, cifar.C100FineLabels)\n",
    "}\n",
    "\n",
    "// sampleImages generates and outputs one HTML table with numImages examples sampled from the dataset ds.\n",
    "func sampleImages(ds train.Dataset, numImages int, labelNames []string) {\n",
    "    gonbui.DisplayHTML(fmt.Sprintf(\"<p>%s</p>\\n\", ds.Name()))\n",
    "    \n",
    "    parts := make([]string, 0, numImages+5) // Leave last part empty.\n",
    "    parts = append(parts, \"<table><tr>\")\n",
    "    for ii := 0; ii < numImages; ii++ {\n",
    "        _, inputs, labels := must.M3(ds.Yield())\n",
    "        imgTensor := inputs[0]\n",
    "        img := images.ToImage().Single(imgTensor)\n",
    "        label := labels[0].Value().([]int64)\n",
    "        labelStr := labelNames[label[0]]\n",
    "    \n",
    "        imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))\n",
    "        size := imgTensor.Shape().Dimensions[0]\n",
    "        parts = append(\n",
    "            parts, \n",
    "            fmt.Sprintf(`<td><figure style=\"padding:4px;text-align: center;\"><img width=\"%d\" height=\"%d\" src=\"%s\">` + \n",
    "                        `<figcaption style=\"text-align: center;\">%s (%d)</figcaption></figure></td>`, \n",
    "                        size*2, size*2, imgSrc, labelStr, label),\n",
    "        )\n",
    "    }\n",
    "    parts = append(parts, \"</tr></table>\", \"\")\n",
    "    gonbui.DisplayHTML(strings.Join(parts, \"\\n\"))\n",
    "}\n",
    "\n",
    "%%\n",
    "AssertDownloaded()\n",
    "sampleToNotebook()"
   ]
  },
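  {
   "cell_type": "markdown",
   "id": "8c4e2a19-3f7b-4d60-a1e5-6b9d0c2f7e34",
   "metadata": {},
   "source": [
    "In the cell above, `images.ToImage` converts each image tensor to a Go image. `gonbui.EmbedImageAsPNGSrc` then turns it into something that can be placed directly in the image tag's `src` attribute, presumably a PNG data URI. Below is a rough standard-library sketch of that idea. `hwcToImage` and `pngDataURI` are hypothetical helpers, not the GoMLX or GoNB implementations. They assume a channels-last `float32` layout with values in [0, 1], which may differ from what the `cifar` dataset yields.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import (\n",
    "    \"bytes\"\n",
    "    \"encoding/base64\"\n",
    "    \"fmt\"\n",
    "    \"image\"\n",
    "    \"image/color\"\n",
    "    \"image/png\"\n",
    ")\n",
    "\n",
    "// hwcToImage (hypothetical) converts a channels-last height x width x 3 slice of float32\n",
    "// values, assumed to be in [0, 1], into an *image.RGBA. Values outside [0, 1] are clamped.\n",
    "func hwcToImage(pixels []float32, height, width int) *image.RGBA {\n",
    "    toByte := func(v float32) uint8 {\n",
    "        if v < 0 {\n",
    "            v = 0\n",
    "        } else if v > 1 {\n",
    "            v = 1\n",
    "        }\n",
    "        return uint8(v*255 + 0.5) // Round to the nearest integer.\n",
    "    }\n",
    "    img := image.NewRGBA(image.Rect(0, 0, width, height))\n",
    "    for y := 0; y < height; y++ {\n",
    "        for x := 0; x < width; x++ {\n",
    "            i := (y*width + x) * 3\n",
    "            img.SetRGBA(x, y, color.RGBA{R: toByte(pixels[i]), G: toByte(pixels[i+1]), B: toByte(pixels[i+2]), A: 255})\n",
    "        }\n",
    "    }\n",
    "    return img\n",
    "}\n",
    "\n",
    "// pngDataURI (hypothetical) encodes img as PNG and wraps it in a base64 data URI, which\n",
    "// can be used directly as the src attribute of an HTML image tag.\n",
    "func pngDataURI(img image.Image) (string, error) {\n",
    "    var buf bytes.Buffer\n",
    "    if err := png.Encode(&buf, img); err != nil {\n",
    "        return \"\", err\n",
    "    }\n",
    "    return \"data:image/png;base64,\" + base64.StdEncoding.EncodeToString(buf.Bytes()), nil\n",
    "}\n",
    "\n",
    "func main() {\n",
    "    pixels := make([]float32, 32*32*3)\n",
    "    pixels[0] = 1 // The top-left pixel is pure red; everything else is black.\n",
    "    uri, err := pngDataURI(hwcToImage(pixels, 32, 32))\n",
    "    if err != nil {\n",
    "        panic(err)\n",
    "    }\n",
    "    fmt.Println(uri[:22], len(uri) > 22) // data:image/png;base64, true\n",
    "}\n",
    "```"
   ]
  },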
  {
   "cell_type": "markdown",
   "id": "7fa8c06c-ede0-4cd1-80bd-9331422e3a63",
   "metadata": {},
   "source": [
    "## Training on Cifar-10\n",
    "\n",
    "### Models Support\n",
    "\n",
    "1. The `\"model\"` hyperparameter selects the model type, out of the `cifar.C10ValidModels` options.\n",
    "1. `createDefaultContext` creates a context and sets the default hyperparameter values for the CIFAR models.\n",
    "1. `ContextFromSettings` uses `createDefaultContext` and incorporates the changes passed with the `-set` flag.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c8d3c19e-af0c-4fef-a76c-dabb375972d0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model types: [\"fnn\" \"kan\" \"cnn\"]\n",
      "Parameters set (-set): [\"batch_size\" \"fnn_num_hidden_layers\"]\n",
      "\t\"/activation\": (string) swish\n",
      "\t\"/adam_dtype\": (string) \n",
      "\t\"/adam_epsilon\": (float64) 1e-07\n",
      "\t\"/batch_size\": (int) 17\n",
      "\t\"/checkpoint\": (string) \n",
      "\t\"/dropout_rate\": (float64) 0\n",
      "\t\"/eval_batch_size\": (int) 200\n",
      "\t\"/fnn_dropout_rate\": (float64) -1\n",
      "\t\"/fnn_normalization\": (string) \n",
      "\t\"/fnn_num_hidden_layers\": (int) 12\n",
      "\t\"/fnn_num_hidden_nodes\": (int) 128\n",
      "\t\"/fnn_residual\": (bool) true\n",
      "\t\"/kan_bspline_degree\": (int) 2\n",
      "\t\"/kan_bspline_magnitude_l1\": (float64) 1e-05\n",
      "\t\"/kan_bspline_magnitude_l2\": (float64) 0\n",
      "\t\"/kan_discrete\": (bool) false\n",
      "\t\"/kan_discrete_softness\": (float64) 0.1\n",
      "\t\"/kan_num_hidden_layers\": (int) 4\n",
      "\t\"/kan_num_hidden_nodes\": (int) 64\n",
      "\t\"/kan_num_points\": (int) 10\n",
      "\t\"/kan_residual\": (bool) true\n",
      "\t\"/l1_regularization\": (float64) 1e-05\n",
      "\t\"/l2_regularization\": (float64) 1e-05\n",
      "\t\"/learning_rate\": (float64) 0.0001\n",
      "\t\"/model\": (string) fnn\n",
      "\t\"/normalization\": (string) none\n",
      "\t\"/num_checkpoints\": (int) 3\n",
      "\t\"/optimizer\": (string) adamw\n",
      "\t\"/plots\": (bool) true\n",
      "\t\"/train_steps\": (int) 3000\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"flag\"\n",
    "    \n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers/fnn\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers/kan\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers/regularizers\"\n",
    "    \"github.com/gomlx/gomlx/ui/commandline\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train/optimizers\"\n",
    "    \"github.com/gomlx/gomlx/examples/cifar\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/context\"\n",
    ")\n",
    "\n",
    "var (\n",
    "    // ValidModels is the list of model types supported.\n",
    "    ValidModels = []string{\"fnn\", \"kan\", \"cnn\"}\n",
    "\tflagEval      = flag.Bool(\"eval\", true, \"Whether to evaluate the model on the validation data in the end.\")\n",
    "\tflagVerbosity = flag.Int(\"verbosity\", 1, \"Level of verbosity, the higher the more verbose.\")\n",
    ")\n",
    "\n",
    "// settings is bound to a \"-set\" flag to be used to set context hyperparameters.\n",
    "var settings = commandline.CreateContextSettingsFlag(createDefaultContext(), \"set\")\n",
    "\n",
    "// createDefaultContext creates a context with the default hyperparameters.\n",
    "func createDefaultContext() *context.Context {\n",
    "\tctx := context.New()\n",
    "\tctx.RngStateReset()\n",
    "\tctx.SetParams(map[string]any{\n",
    "        // Model type to use: valid values are fnn, kan and cnn.\n",
    "\t\t\"model\":           cifar.C10ValidModels[0],\n",
    "\t\t\"checkpoint\":      \"\",\n",
    "\t\t\"num_checkpoints\": 3,\n",
    "\t\t\"train_steps\":     3000,\n",
    "\n",
    "\t\t// batch_size for training.\n",
    "\t\t\"batch_size\": 64,\n",
    "\n",
    "\t\t// eval_batch_size can be larger than training, it's more efficient.\n",
    "\t\t\"eval_batch_size\": 200,\n",
    "\n",
    "\t\t// \"plots\" trigger generating intermediary eval data for plotting, and if running in GoNB, to actually\n",
    "\t\t// draw the plot with Plotly.\n",
    "\t\tplotly.ParamPlots: true,\n",
    "\n",
    "\t\t// If \"normalization\" is set, it overrides \"fnn_normalization\" and \"cnn_normalization\".\n",
    "\t\tlayers.ParamNormalization: \"none\",\n",
    "\n",
    "\t\toptimizers.ParamOptimizer:           \"adamw\",\n",
    "\t\toptimizers.ParamLearningRate:        1e-4,\n",
    "\t\toptimizers.ParamAdamEpsilon:         1e-7,\n",
    "\t\toptimizers.ParamAdamDType:           \"\",\n",
    "\t\tactivations.ParamActivation:         \"swish\",\n",
    "\t\tlayers.ParamDropoutRate:             0.0,\n",
    "\t\tregularizers.ParamL2:                1e-5,\n",
    "\t\tregularizers.ParamL1:                1e-5,\n",
    "\n",
    "\t\t// FNN network parameters:\n",
    "\t\tfnn.ParamNumHiddenLayers: 8,\n",
    "\t\tfnn.ParamNumHiddenNodes:  128,\n",
    "\t\tfnn.ParamResidual:        true,\n",
    "\t\tfnn.ParamNormalization:   \"\",   // Set to none for no normalization, otherwise it falls back to layers.ParamNormalization.\n",
    "\t\tfnn.ParamDropoutRate:     -1.0, // Set to 0.0 for no dropout, otherwise it falls back to layers.ParamDropoutRate.\n",
    "\n",
    "\t\t// KAN network parameters:\n",
    "\t\tkan.ParamNumControlPoints:   10, // Number of control points\n",
    "\t\tkan.ParamNumHiddenNodes:     64,\n",
    "\t\tkan.ParamNumHiddenLayers:    4,\n",
    "\t\tkan.ParamBSplineDegree:      2,\n",
    "\t\tkan.ParamBSplineMagnitudeL1: 1e-5,\n",
    "\t\tkan.ParamBSplineMagnitudeL2: 0.0,\n",
    "\t\tkan.ParamDiscrete:           false,\n",
    "\t\tkan.ParamDiscreteSoftness:   0.1,\n",
    "        kan.ParamResidual:           true,\n",
    "\t})\n",
    "\treturn ctx\n",
    "}\n",
    "\n",
    "// ContextFromSettings is the default context (createDefaultContext) changed by -set flag.\n",
    "func ContextFromSettings() (ctx *context.Context, paramsSet []string) {\n",
    "    ctx = createDefaultContext()\n",
    "    paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))\n",
    "    return\n",
    "}\n",
    "\n",
    "// Let's test that hyperparameters can be set with the \"-set\" flag:\n",
    "%% -set=\"batch_size=17;fnn_num_hidden_layers=12\"\n",
    "fmt.Printf(\"Model types: %q\\n\", cifar.C10ValidModels)\n",
    "ctx, parametersSet := ContextFromSettings()\n",
    "fmt.Printf(\"Parameters set (-set): %q\\n\", parametersSet)\n",
    "fmt.Println(commandline.SprintContextSettings(ctx))"
   ]
  },
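  {
   "cell_type": "markdown",
   "id": "5b0e7c2a-9d41-4f6a-8e3b-2c7d9a1f4e60",
   "metadata": {},
   "source": [
    "The `-set` flag above is parsed by `commandline.ParseContextSettings`. Here is a simplified standalone sketch of the idea. Each `;`-separated `key=value` pair must name an existing hyperparameter, and its value is parsed according to the type of the default value. `applySettings` is hypothetical, not the GoMLX implementation. It only handles `int`, `float64`, `bool` and `string` values, stored in a plain map instead of a `context.Context`.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import (\n",
    "    \"fmt\"\n",
    "    \"sort\"\n",
    "    \"strconv\"\n",
    "    \"strings\"\n",
    ")\n",
    "\n",
    "// applySettings (hypothetical) parses settings such as \"batch_size=17;fnn_residual=false\",\n",
    "// converts each value to the type of the existing default in params, updates params and\n",
    "// returns the sorted list of keys that were set. A value that fails to parse is not stored.\n",
    "func applySettings(params map[string]any, settings string) ([]string, error) {\n",
    "    var keysSet []string\n",
    "    for _, part := range strings.Split(settings, \";\") {\n",
    "        part = strings.TrimSpace(part)\n",
    "        if part == \"\" {\n",
    "            continue\n",
    "        }\n",
    "        key, value, found := strings.Cut(part, \"=\")\n",
    "        if !found {\n",
    "            return nil, fmt.Errorf(\"setting %q is not in the form key=value\", part)\n",
    "        }\n",
    "        defaultValue, exists := params[key]\n",
    "        if !exists {\n",
    "            return nil, fmt.Errorf(\"unknown hyperparameter %q\", key)\n",
    "        }\n",
    "        var newValue any\n",
    "        var err error\n",
    "        switch defaultValue.(type) {\n",
    "        case int:\n",
    "            newValue, err = strconv.Atoi(value)\n",
    "        case float64:\n",
    "            newValue, err = strconv.ParseFloat(value, 64)\n",
    "        case bool:\n",
    "            newValue, err = strconv.ParseBool(value)\n",
    "        case string:\n",
    "            newValue = value\n",
    "        default:\n",
    "            err = fmt.Errorf(\"unsupported type %T\", defaultValue)\n",
    "        }\n",
    "        if err != nil {\n",
    "            return nil, fmt.Errorf(\"hyperparameter %q: %w\", key, err)\n",
    "        }\n",
    "        params[key] = newValue\n",
    "        keysSet = append(keysSet, key)\n",
    "    }\n",
    "    sort.Strings(keysSet)\n",
    "    return keysSet, nil\n",
    "}\n",
    "\n",
    "func main() {\n",
    "    params := map[string]any{\"batch_size\": 64, \"learning_rate\": 1e-4, \"fnn_residual\": true, \"model\": \"fnn\"}\n",
    "    keysSet, err := applySettings(params, \"fnn_residual=false;batch_size=17\")\n",
    "    if err != nil {\n",
    "        panic(err)\n",
    "    }\n",
    "    fmt.Println(keysSet, params[\"batch_size\"], params[\"fnn_residual\"]) // [batch_size fnn_residual] 17 false\n",
    "}\n",
    "```"
   ]
  },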
  {
   "cell_type": "markdown",
   "id": "f771ca6f-5912-4406-9f2a-a838fd7cd065",
   "metadata": {},
   "source": [
    "### Simple FNN model\n",
    "\n",
    "A trivial model that can easily get to ~50% accuracy (a random model would get 10%), but hardly more than that.\n",
    "\n",
    "Later we are going to define a CNN model for comparison; for now, this FNN serves as a placeholder model.\n",
    "\n",
    "> **Note**:\n",
    ">\n",
    "> * The code is here only as an example. The training below actually uses the same code from the [`cifar`](https://github.com/gomlx/gomlx/tree/main/examples/cifar) package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0ac4e4da-2c3b-431c-9e54-00e31045785f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logits shape for batch_size=7: (Float32)[7 10]\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"github.com/gomlx/gomlx/pkg/core/graph\"\n",
    "    . \"github.com/gomlx/gomlx/pkg/core/graph\"\n",
    "    \"github.com/gomlx/gomlx/examples/cifar\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/context\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train/optimizers\"\n",
    "    \"github.com/gomlx/gomlx/pkg/core/shapes\"\n",
    ")\n",
    "\n",
    "var _ = NewGraph  // Make sure the graph package is in use.\n",
    "\n",
    "// C10PlainModelGraph implements train.ModelFn, and returns the logit Node, given the input image.\n",
    "// It's a basic FNN (Feedforward Neural Network), so no convolutions. It is meant only as an example.\n",
    "func C10PlainModelGraph(ctx *context.Context, spec any, inputs []*graph.Node) []*graph.Node {\n",
    "    _ = spec  // Not needed, the input type is always the same.\n",
    "\tctx = ctx.In(\"model\")\n",
    "\tbatchedImages := inputs[0]\n",
    "\tbatchSize := batchedImages.Shape().Dimensions[0]\n",
    "\tlogits := graph.Reshape(batchedImages, batchSize, -1)\n",
    "\tnumClasses := len(cifar.C10Labels)\n",
    "\tmodelType := context.GetParamOr(ctx, \"model\", cifar.C10ValidModels[0])\n",
    "\tif modelType == \"kan\" {\n",
    "\t\t// The KAN layer(s) are configured from the context hyperparameters.\n",
    "\t\tlogits = kan.New(ctx, logits, numClasses).Done()\n",
    "\t} else {\n",
    "\t\t// The FNN layer(s) are configured from the context hyperparameters.\n",
    "\t\tlogits = fnn.New(ctx, logits, numClasses).Done()\n",
    "\t}\n",
    "\tlogits.AssertDims(batchSize, numClasses)\n",
    "\treturn []*graph.Node{logits}\n",
    "}\n",
    "\n",
    "%% -set=\"batch_size=7\"\n",
    "// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.\n",
    "AssertDownloaded()\n",
    "ctx, _ := ContextFromSettings()\n",
    "g := NewGraph(backends.MustNew(), \"placeholder\")\n",
    "batchSize := context.GetParamOr(ctx, \"batch_size\", int(100))\n",
    "logits := C10PlainModelGraph(ctx, nil, []*Node{Parameter(g, \"images\", shapes.Make(DType, batchSize, cifar.Height, cifar.Width, cifar.Depth))})\n",
    "fmt.Printf(\"Logits shape for batch_size=%d: %s\\n\", batchSize, logits[0].Shape())"
   ]
  },
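  {
   "cell_type": "markdown",
   "id": "e3a91f57-0c2d-4b8e-9a64-1f7c3d5b8e29",
   "metadata": {},
   "source": [
    "`fnn.New` builds its layers from the context hyperparameters. With the defaults above, each 32x32x3 image is flattened into a 3072-element vector and passed through the hidden layers of 128 nodes with the swish activation and residual connections. A final linear layer then produces 10 logits. This assumes `fnn_num_hidden_layers` counts the hidden dense layers. As a rough idea of that computation, here is a plain-Go sketch of a residual FNN forward pass for a single image. It is a simplification written for illustration, not the `fnn` package implementation: no normalization, dropout or regularization, an assumed He-style initialization, and one image instead of a batch.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import (\n",
    "    \"fmt\"\n",
    "    \"math\"\n",
    "    \"math/rand\"\n",
    ")\n",
    "\n",
    "// dense is a fully connected layer computing y = x*W + b.\n",
    "type dense struct {\n",
    "    w    [][]float32 // Shape [inputDim][outputDim].\n",
    "    bias []float32   // Shape [outputDim].\n",
    "}\n",
    "\n",
    "// newDense initializes the weights with a He-style scaled normal distribution (an assumption).\n",
    "func newDense(rng *rand.Rand, inputDim, outputDim int) dense {\n",
    "    scale := math.Sqrt(2 / float64(inputDim))\n",
    "    w := make([][]float32, inputDim)\n",
    "    for i := range w {\n",
    "        w[i] = make([]float32, outputDim)\n",
    "        for j := range w[i] {\n",
    "            w[i][j] = float32(rng.NormFloat64() * scale)\n",
    "        }\n",
    "    }\n",
    "    return dense{w: w, bias: make([]float32, outputDim)}\n",
    "}\n",
    "\n",
    "func (d dense) apply(x []float32) []float32 {\n",
    "    y := append([]float32(nil), d.bias...)\n",
    "    for i, xi := range x {\n",
    "        for j, wij := range d.w[i] {\n",
    "            y[j] += xi * wij\n",
    "        }\n",
    "    }\n",
    "    return y\n",
    "}\n",
    "\n",
    "// swish(x) = x * sigmoid(x).\n",
    "func swish(x float32) float32 {\n",
    "    return x / (1 + float32(math.Exp(float64(-x))))\n",
    "}\n",
    "\n",
    "// fnnLogits runs one flattened image through the hidden layers (swish activation, plus a\n",
    "// residual connection when input and output widths match) and a final linear layer that\n",
    "// outputs one logit per class.\n",
    "func fnnLogits(image []float32, hidden []dense, output dense) []float32 {\n",
    "    x := image\n",
    "    for _, layer := range hidden {\n",
    "        y := layer.apply(x)\n",
    "        for j := range y {\n",
    "            y[j] = swish(y[j])\n",
    "            if len(x) == len(y) {\n",
    "                y[j] += x[j] // Residual connection.\n",
    "            }\n",
    "        }\n",
    "        x = y\n",
    "    }\n",
    "    return output.apply(x)\n",
    "}\n",
    "\n",
    "func main() {\n",
    "    const inputDim, numHiddenNodes, numHiddenLayers, numClasses = 32 * 32 * 3, 128, 8, 10\n",
    "    rng := rand.New(rand.NewSource(42))\n",
    "    hidden := make([]dense, numHiddenLayers)\n",
    "    for i := range hidden {\n",
    "        in := numHiddenNodes\n",
    "        if i == 0 {\n",
    "            in = inputDim\n",
    "        }\n",
    "        hidden[i] = newDense(rng, in, numHiddenNodes)\n",
    "    }\n",
    "    output := newDense(rng, numHiddenNodes, numClasses)\n",
    "    image := make([]float32, inputDim) // Stand-in for one flattened 32x32x3 image.\n",
    "    for i := range image {\n",
    "        image[i] = rng.Float32()\n",
    "    }\n",
    "    fmt.Println(\"number of logits:\", len(fnnLogits(image, hidden, output))) // number of logits: 10\n",
    "}\n",
    "```"
   ]
  },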
  {
   "cell_type": "markdown",
   "id": "aef7d6ee-ca2b-40c7-ac3d-4242ecd8ed1b",
   "metadata": {},
   "source": [
    "### Training Loop\n",
    "\n",
    "With a model function defined, we use the training loop created for Cifar-10.\n",
    "\n",
    "The trainer is provided in the [`cifar` package](https://github.com/gomlx/gomlx/tree/main/examples/cifar). It is straightforward (and almost the same for every project), and it does the following for us:\n",
    "\n",
    "- If a checkpoint directory is given (`--checkpoint`) and it contains a previously saved model, it loads the hyperparameters and the trained variables.\n",
    "- Creates a trainer with the selected model function (see the [Simple FNN model](#Simple-FNN-model) and [CNN model for Cifar10](#CNN-model-for-Cifar10) sections), optimizer, loss and metrics.\n",
    "- Creates a `train.Loop` and attaches to it a progress bar, a periodic checkpoint saver and a plotter (`--set=\"plots=true\"`).\n",
    "- Trains for the selected number of training steps.\n",
    "- Reports the results.\n",
    "\n",
    "Below we train for only 50 steps, without plots and otherwise with the default settings, just to check that things are working."
   ]
  },
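  {
   "cell_type": "markdown",
   "id": "71c5d8e2-4a3b-4f9c-b0d6-9e2a8f1c5d47",
   "metadata": {},
   "source": [
    "The trainer reports the loss (`#loss`), the loss plus regularization (`#loss+`) and the accuracy (`#acc`). For a 10-class classifier like this one, the loss is presumably a softmax cross-entropy. The sketch below shows the textbook definitions in plain Go. The function names are hypothetical, and this is not GoMLX's implementation. It also shows why an untrained model starts with a loss of around 2.3: with equal logits, every class gets probability 1/10, so the loss is ln(10) ≈ 2.30.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import (\n",
    "    \"fmt\"\n",
    "    \"math\"\n",
    ")\n",
    "\n",
    "// softmaxCrossEntropy returns -log(softmax(logits)[label]), using the log-sum-exp trick\n",
    "// (subtracting the largest logit before exponentiating) for numerical stability.\n",
    "func softmaxCrossEntropy(logits []float64, label int) float64 {\n",
    "    maxLogit := math.Inf(-1)\n",
    "    for _, l := range logits {\n",
    "        maxLogit = math.Max(maxLogit, l)\n",
    "    }\n",
    "    sumExp := 0.0\n",
    "    for _, l := range logits {\n",
    "        sumExp += math.Exp(l - maxLogit)\n",
    "    }\n",
    "    return maxLogit + math.Log(sumExp) - logits[label]\n",
    "}\n",
    "\n",
    "// argmax returns the index of the largest logit, that is, the predicted class.\n",
    "func argmax(logits []float64) int {\n",
    "    best := 0\n",
    "    for i, l := range logits {\n",
    "        if l > logits[best] {\n",
    "            best = i\n",
    "        }\n",
    "    }\n",
    "    return best\n",
    "}\n",
    "\n",
    "// meanLossAndAccuracy averages the loss and the fraction of correct predictions over a batch.\n",
    "func meanLossAndAccuracy(batchLogits [][]float64, labels []int) (loss, accuracy float64) {\n",
    "    for i, logits := range batchLogits {\n",
    "        loss += softmaxCrossEntropy(logits, labels[i])\n",
    "        if argmax(logits) == labels[i] {\n",
    "            accuracy++\n",
    "        }\n",
    "    }\n",
    "    n := float64(len(labels))\n",
    "    return loss / n, accuracy / n\n",
    "}\n",
    "\n",
    "func main() {\n",
    "    // Equal logits give each of the 10 classes probability 1/10, so the loss is ln(10).\n",
    "    fmt.Printf(\"loss for uniform logits: %.4f\\n\", softmaxCrossEntropy(make([]float64, 10), 3)) // 2.3026\n",
    "    loss, accuracy := meanLossAndAccuracy([][]float64{{2, 0, 0}, {0, 0, 5}}, []int{0, 1})\n",
    "    fmt.Printf(\"mean loss: %.4f, accuracy: %.0f%%\\n\", loss, accuracy*100)\n",
    "}\n",
    "```"
   ]
  },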
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d72336ef-d96d-4d98-a3a6-5df912be5f24",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backend \"stablehlo\":\tstablehlo:cuda - PJRT \"cuda\" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO]\n",
      "      \u001b[1m 100% [========================================] (2119 steps/s)\u001b[0m [step=49] [loss+=2.45] [~loss+=3.05] [~loss=2.84] [~acc=16.91%]                \n",
      "\t[Step 50] median train step: 386 microseconds\n",
      "\n",
      "Results on Validation:\n",
      "\tMean Loss+Regularization (#loss+): 2.44\n",
      "\tMean Loss (#loss): 2.23\n",
      "\tMean Accuracy (#acc): 20.19%\n",
      "Results on Training:\n",
      "\tMean Loss+Regularization (#loss+): 2.44\n",
      "\tMean Loss (#loss): 2.23\n",
      "\tMean Accuracy (#acc): 19.72%\n"
     ]
    }
   ],
   "source": [
    "var flagCheckpoint = flag.String(\"checkpoint\", \"\", \"Directory to save and load checkpoints from. If left empty, no checkpoints are created.\")\n",
    "\n",
    "// trainModel with hyperparameters configured with `-set=...`.\n",
    "func trainModel() {\n",
    "    ctx, paramsSet := ContextFromSettings()\n",
    "    cifar.TrainCifar10Model(ctx, *flagDataDir, *flagCheckpoint, *flagEval, *flagVerbosity, paramsSet)\n",
    "}\n",
    "\n",
    "// Train 50 steps, only to test things are working. No plots.\n",
    "%% --set=\"train_steps=50;plots=false\"\n",
    "trainModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "300a8419-5679-42ba-9149-51c17c7c427d",
   "metadata": {},
   "source": [
    "### FNN Model Training\n",
    "\n",
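    "Let's train the FNN for real this time.\n",
    "\n",
    "* **Note**: The FNN model quickly overfits the training data. In the run below, accuracy on the training data climbs to ~85%, while accuracy on the validation data plateaus at around 50-52%."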
   ]
  },
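  {
   "cell_type": "markdown",
   "id": "a6f0b3d9-2e58-4c71-8d4a-3b9e7f0c1a62",
   "metadata": {},
   "source": [
    "One common way to see (and limit) overfitting is to track where the validation loss stops improving, and to keep the model from that point (early stopping). This notebook does not say whether the `cifar` trainer does any early stopping. The sketch below only illustrates the selection step. It uses a handful of *Mean Loss on Validation* values read off the loss plot of the run below (rounded). `bestStep` is a hypothetical helper.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import \"fmt\"\n",
    "\n",
    "// bestStep returns the evaluation step with the lowest validation loss.\n",
    "// steps and validationLoss must be non-empty and have the same length.\n",
    "func bestStep(steps []int, validationLoss []float64) (step int, loss float64) {\n",
    "    best := 0\n",
    "    for i := range validationLoss {\n",
    "        if validationLoss[i] < validationLoss[best] {\n",
    "            best = i\n",
    "        }\n",
    "    }\n",
    "    return steps[best], validationLoss[best]\n",
    "}\n",
    "\n",
    "func main() {\n",
    "    // A few (step, mean validation loss) points from the loss plot of the run below, rounded.\n",
    "    steps := []int{5200, 11861, 21228, 31011, 50000}\n",
    "    validationLoss := []float64{1.487, 1.390, 1.472, 1.687, 2.144}\n",
    "    step, loss := bestStep(steps, validationLoss)\n",
    "    fmt.Printf(\"lowest validation loss %.3f at step %d\\n\", loss, step) // lowest validation loss 1.390 at step 11861\n",
    "}\n",
    "```"
   ]
  },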
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4de9b879-4790-4aac-ace7-3c8a348ac961",
   "metadata": {},
   "outputs": [],
   "source": [
    "// Remove a previously trained model. Skip this if you want to continue training a previous model.\n",
    "!rm -rf ~/work/cifar/base_fnn_model  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cf876f49-e737-4dfa-aa05-48aad3d7a128",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backend \"stablehlo\":\tstablehlo:cuda - PJRT \"cuda\" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO]\n",
      "Checkpointing model to \"/home/janpf/work/cifar/base_fnn_model\"\n"
     ]
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m   1% [........................................] (391 steps/s) [0s:2m5s]\u001b[0m [step=699] [loss+=1.93] [~loss+=1.98] [~loss=1.76] [~acc=36.94%]        "
     ]
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m 100% [========================================] (2285 steps/s)\u001b[0m [step=49999] [loss+=0.472] [~loss+=0.68] [~loss=0.454] [~acc=83.93%]        ]         \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: accuracy</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"7b3688c8\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Accuracy\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[0.2718499004840851,0.33310237526893616,0.3727899491786957,0.40385714173316956,0.4214908182621002,0.43983137607574463,0.45598113536834717,0.4709991216659546,0.4949967563152313,0.5142630934715271,0.5370389223098755,0.5564768314361572,0.5729963779449463,0.6066710352897644,0.6162790656089783,0.6383858919143677,0.6856624484062195,0.7130618691444397,0.7344680428504944,0.7787956595420837,0.823291540145874,0.8393142819404602]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on Training\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[0.2892400026321411,0.3479999899864197,0.40359997749328613,0.4195399880409241,0.4402399957180023,0.4468799829483032,0.46313998103141785,0.46173998713493347,0.5036999583244324,0.5111600160598755,0.5507799983024597,0.5665599703788757,0.5873599648475647,0.6113199591636658,0.6364399790763855,0.6594600081443787,0.6951599717140198,0.7319999933242798,0.7418799996376038,0.788379967212677,0.8439799547195435,0.8496399521827698]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on Validation\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[0.29090002179145813,0.34530001878738403,0.40570002794265747,0.4147000312805176,0.42260003089904785,0.43540000915527344,0.45190003514289856,0.43310001492500305,0.4723000228404999,0.4686000347137451,0.5001000165939331,0.5100000500679016,0.5139999985694885,0.5200000405311584,0.5188000202178955,0.5188000202178955,0.5246000289916992,0.5204000473022461,0.5164999961853027,0.5083000063896179,0.5190000534057617,0.5064000487327576]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"accuracy\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('7b3688c8', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: loss</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"da627516\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Batch Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[2.3047239780426025,2.0565249919891357,1.8981355428695679,1.7886199951171875,1.6463298797607422,1.947914958000183,1.684484601020813,1.916835904121399,1.5099726915359497,1.514784574508667,1.5636131763458252,1.642513632774353,1.4374361038208008,1.1111009120941162,1.19437575340271,1.239422082901001,1.2347511053085327,1.28165602684021,0.7321380972862244,0.7985636591911316,0.5997051000595093,0.47181737422943115]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[2.3419415950775146,2.0864555835723877,1.972267746925354,1.8823868036270142,1.8464716672897339,1.7811715602874756,1.7278045415878296,1.6905312538146973,1.622129201889038,1.589526653289795,1.5079638957977295,1.4602669477462769,1.4129899740219116,1.3253189325332642,1.281771183013916,1.223850131034851,1.1003196239471436,1.0367711782455444,0.9583731889724731,0.8412587642669678,0.7297943830490112,0.6803523898124695]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[2.125455617904663,1.8700706958770752,1.7560051679611206,1.6662753820419312,1.6305737495422363,1.5655443668365479,1.5124746561050415,1.4755487442016602,1.4075313806533813,1.375354290008545,1.2942408323287964,1.2469351291656494,1.199937105178833,1.112388014793396,1.0686962604522705,1.0103092193603516,0.8858439326286316,0.8208240866661072,0.7402878403663635,0.6204277276992798,0.5055776238441467,0.4540979266166687]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on Training\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[2.1918601989746094,2.034148693084717,1.8932268619537354,1.8499062061309814,1.7898163795471191,1.761637568473816,1.731433391571045,1.711868405342102,1.6040719747543335,1.5801236629486084,1.4817596673965454,1.4320400953292847,1.3759907484054565,1.307309865951538,1.23677396774292,1.1808717250823975,1.0806201696395874,0.9868822693824768,0.9505578875541687,0.8223467469215393,0.6734500527381897,0.6603193879127502]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on Training\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[1.975412130355835,1.8178071975708008,1.6770073175430298,1.6338436603546143,1.5739741325378418,1.546060562133789,1.5161515474319458,1.496924638748169,1.3895082473754883,1.3659896850585938,1.2680485248565674,1.21871817111969,1.1629536151885986,1.0943665504455566,1.0236914157867432,0.9673190116882324,0.8661024570465088,0.770906388759613,0.7324333190917969,0.6014794707298279,0.44919249415397644,0.4340272545814514]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on Validation\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[2.192713975906372,2.037653684616089,1.9061682224273682,1.8657385110855103,1.8201364278793335,1.8061549663543701,1.780536413192749,1.7838258743286133,1.6960991621017456,1.7012794017791748,1.6345783472061157,1.6267406940460205,1.6206631660461426,1.6032363176345825,1.6155810356140137,1.6454981565475464,1.6860487461090088,1.7482020854949951,1.904830813407898,2.0721731185913086,2.2476465702056885,2.369826078414917]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on Validation\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,50000],\"y\":[1.9762659072875977,1.8213108777999878,1.689948320388794,1.6496760845184326,1.6042943000793457,1.5905776023864746,1.565255045890808,1.5688823461532593,1.4815359115600586,1.487146019935608,1.4208672046661377,1.4134190082550049,1.4076263904571533,1.3902925252914429,1.4024983644485474,1.4319450855255127,1.4715303182601929,1.5322257280349731,1.686706304550171,1.8513057231903076,2.0233888626098633,2.143533945083618]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"loss\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('da627516', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t[Step 50000] median train step: 343 microseconds\n",
      "\n",
      "Results on Validation:\n",
      "\tMean Loss+Regularization (#loss+): 2.37\n",
      "\tMean Loss (#loss): 2.14\n",
      "\tMean Accuracy (#acc): 50.64%\n",
      "Results on Training:\n",
      "\tMean Loss+Regularization (#loss+): 0.66\n",
      "\tMean Loss (#loss): 0.434\n",
      "\tMean Accuracy (#acc): 84.96%\n"
     ]
    }
   ],
   "source": [
    "%% --checkpoint=base_fnn_model --set=\"model=fnn;train_steps=50_000;plots=true\"\n",
    "trainModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "942fdb36-4fbe-4ccf-9958-c8b67b036ada",
   "metadata": {},
   "source": [
    "### CNN model for Cifar-10\n",
    "\n",
    "Let's now define a proper CNN model for comparison.\n",
    "\n",
    "The model was built following a [Keras model in Kaggle (thanks @ektasharma)](https://www.kaggle.com/code/ektasharma/simple-cifar10-cnn-keras-code-with-88-accuracy),\n",
    "which provides hardcoded values for all layers of the model, so it doesn't use the hyperparameters set in the context. In the run below it reaches \\~81% accuracy on the validation set (\\~87% on the training set) after 80,000 steps of batch size 64 (\\~100 epochs).\n",
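    "\n",
    "For reference, the number of epochs follows from the 50,000 images in the Cifar-10 training set:\n",
    "\n",
    "$$\\frac{80{,}000 \\text{ steps} \\times 64 \\text{ images/step}}{50{,}000 \\text{ images/epoch}} = 102.4 \\approx 100 \\text{ epochs}$$\n",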
    "\n",
    "Notice that since the model uses batch normalization, the training process will, at the end, update the averages of the mean and variance using the final model weights: this improves on the moving averages accumulated during training, which blend in statistics computed with older weights. This [interesting blog post](https://www.mindee.com/blog/batch-normalization) explains it in more detail."
   ]
  },
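  {
   "cell_type": "markdown",
   "id": "b7e2c1a4-5d3f-4e8a-9c6b-2f1d0a9e8b71",
   "metadata": {},
   "source": [
    "A generic sketch of the batch normalization statistics involved, in standard notation (not necessarily GoMLX's exact implementation): during training, each batch $\\mathcal{B}$ with per-channel mean $\\mu_{\\mathcal{B}}$ and variance $\\sigma^2_{\\mathcal{B}}$ updates the running estimates with a momentum $m$ (a value close to 1, for example 0.99). At inference the layer normalizes its input $x$ with those estimates instead of the batch statistics, using the learned scale $\\gamma$ and offset $\\beta$:\n",
    "\n",
    "$$\\hat{\\mu} \\leftarrow m\\,\\hat{\\mu} + (1-m)\\,\\mu_{\\mathcal{B}}, \\qquad \\hat{\\sigma}^2 \\leftarrow m\\,\\hat{\\sigma}^2 + (1-m)\\,\\sigma^2_{\\mathcal{B}}$$\n",
    "\n",
    "$$y = \\gamma\\,\\frac{x - \\hat{\\mu}}{\\sqrt{\\hat{\\sigma}^2 + \\epsilon}} + \\beta$$\n",
    "\n",
    "Re-estimating $\\hat{\\mu}$ and $\\hat{\\sigma}^2$ after training, for example by averaging $\\mu_{\\mathcal{B}}$ and $\\sigma^2_{\\mathcal{B}}$ over the batches of one epoch with the weights frozen, makes the inference statistics consistent with the final weights."
   ]
  },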
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e27dc442-c3aa-4ec5-98e9-e5321b474785",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logits shape for batch_size=11: (Float32)[11 10]\n"
     ]
    }
   ],
   "source": [
    "// ConvolutionModelGraph implements train.ModelFn and returns the logits Node, given the batch of input images.\n",
    "// It's a straightforward CNN (Convolutional Neural Network) model.\n",
    "//\n",
    "// This is modeled after the Keras example in Kaggle:\n",
    "// https://www.kaggle.com/code/ektasharma/simple-cifar10-cnn-keras-code-with-88-accuracy\n",
    "// (Thanks @ektasharma)\n",
    "func ConvolutionModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {\n",
    "\tctx = ctx.In(\"model\")\n",
    "\tbatchedImages := inputs[0]\n",
    "\tg := batchedImages.Graph()\n",
    "\tdtype := batchedImages.DType()\n",
    "\tbatchSize := batchedImages.Shape().Dimensions[0]\n",
    "\tlogits := batchedImages\n",
    "\n",
    "\tlayerIdx := 0\n",
    "\tnextCtx := func(name string) *context.Context {\n",
    "\t\tnewCtx := ctx.Inf(\"%03d_%s\", layerIdx, name)\n",
    "\t\tlayerIdx++\n",
    "\t\treturn newCtx\n",
    "\t}\n",
    "\n",
    "\tlogits = layers.Convolution(nextCtx(\"conv\"), logits).Filters(32).KernelSize(3).PadSame().Done()\n",
    "\tlogits.AssertDims(batchSize, 32, 32, 32)\n",
    "\tlogits = activations.Relu(logits)\n",
    "\tlogits = batchnorm.New(nextCtx(\"batchnorm\"), logits, -1).Done()\n",
    "\tlogits = layers.Convolution(nextCtx(\"conv\"), logits).Filters(32).KernelSize(3).PadSame().Done()\n",
    "\tlogits = activations.Relu(logits)\n",
    "\tlogits = batchnorm.New(nextCtx(\"batchnorm\"), logits, -1).Done()\n",
    "\tlogits = MaxPool(logits).Window(2).Done()\n",
    "\tlogits = layers.DropoutNormalize(nextCtx(\"dropout\"), logits, Scalar(g, dtype, 0.3), true)\n",
    "\tlogits.AssertDims(batchSize, 16, 16, 32)\n",
    "\n",
    "\tlogits = layers.Convolution(nextCtx(\"conv\"), logits).Filters(64).KernelSize(3).PadSame().Done()\n",
    "\tlogits.AssertDims(batchSize, 16, 16, 64)\n",
    "\tlogits = activations.Relu(logits)\n",
    "\tlogits = batchnorm.New(nextCtx(\"batchnorm\"), logits, -1).Done()\n",
    "\tlogits = layers.Convolution(nextCtx(\"conv\"), logits).Filters(64).KernelSize(3).PadSame().Done()\n",
    "\tlogits.AssertDims(batchSize, 16, 16, 64)\n",
    "\tlogits = activations.Relu(logits)\n",
    "\tlogits = batchnorm.New(nextCtx(\"batchnorm\"), logits, -1).Done()\n",
    "\tlogits = MaxPool(logits).Window(2).Done()\n",
    "\tlogits = layers.DropoutNormalize(nextCtx(\"dropout\"), logits, Scalar(g, dtype, 0.5), true)\n",
    "\tlogits.AssertDims(batchSize, 8, 8, 64)\n",
    "\n",
    "\tlogits = layers.Convolution(nextCtx(\"conv\"), logits).Filters(128).KernelSize(3).PadSame().Done()\n",
    "\tlogits.AssertDims(batchSize, 8, 8, 128)\n",
    "\tlogits = activations.Relu(logits)\n",
    "\tlogits = batchnorm.New(nextCtx(\"batchnorm\"), logits, -1).Done()\n",
    "\tlogits = layers.Convolution(nextCtx(\"conv\"), logits).Filters(128).KernelSize(3).PadSame().Done()\n",
    "\tlogits.AssertDims(batchSize, 8, 8, 128)\n",
    "\tlogits = activations.Relu(logits)\n",
    "\tlogits = batchnorm.New(nextCtx(\"batchnorm\"), logits, -1).Done()\n",
    "\tlogits = MaxPool(logits).Window(2).Done()\n",
    "\tlogits = layers.DropoutNormalize(nextCtx(\"dropout\"), logits, Scalar(g, dtype, 0.5), true)\n",
    "\tlogits.AssertDims(batchSize, 4, 4, 128)\n",
    "\n",
    "\t// Flatten logits to [batchSize, 4*4*128], so the usual FNN/KAN layers can be applied.\n",
    "\tlogits = Reshape(logits, batchSize, -1)\n",
    "\tlogits = layers.Dense(nextCtx(\"dense\"), logits, true, 128)\n",
    "\tlogits = activations.Relu(logits)\n",
    "\tlogits = batchnorm.New(nextCtx(\"batchnorm\"), logits, -1).Done()\n",
    "\tlogits = layers.DropoutNormalize(nextCtx(\"dropout\"), logits, Scalar(g, dtype, 0.5), true)\n",
    "\tnumClasses := len(cifar.C10Labels)\n",
    "\tlogits = layers.Dense(nextCtx(\"dense\"), logits, true, numClasses)\n",
    "\treturn []*Node{logits}\n",
    "}\n",
    "\n",
    "%% -set=\"batch_size=11\"\n",
    "// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.\n",
    "AssertDownloaded()\n",
    "ctx, _ := ContextFromSettings()\n",
    "g := NewGraph(backends.MustNew(), \"placeholder\")\n",
    "batchSize := context.GetParamOr(ctx, \"batch_size\", int(100))\n",
    "logits := ConvolutionModelGraph(ctx, nil, []*Node{Parameter(g, \"images\", shapes.Make(DType, batchSize, cifar.Height, cifar.Width, cifar.Depth))})\n",
    "fmt.Printf(\"Logits shape for batch_size=%d: %s\\n\", batchSize, logits[0].Shape())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e3b2b04-8ca7-49d6-921d-3527c9169da9",
   "metadata": {},
   "source": [
    "### Training the CNN model\n",
    "\n",
    "CNNs have a much better inductive bias for images: this model easily reaches > 80% accuracy on the training data, and somewhat less on the validation data, due to overfitting.\n",
    "\n",
    "It would likely benefit from pre-training on a larger unlabeled dataset: see the \"Dogs vs Cats\" example for transfer learning in action with an image model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6af158e6-b9f9-423a-8ef9-3b54366d35ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf ~/work/cifar/base_cnn_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b2739281-6e70-4bd4-90e5-ed493678b8ce",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backend \"stablehlo\":\tstablehlo:cuda - PJRT \"cuda\" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO]\n",
      "Checkpointing model to \"/home/janpf/work/cifar/base_cnn_model\"\n"
     ]
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m   0% [........................................] (517 steps/s) [1s:2m33s]\u001b[0m [step=719] [loss+=1.66] [~loss+=1.67] [~loss=1.67] [~acc=37.29%]        "
     ]
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m 100% [========================================] (969 steps/s)\u001b[0m [step=79999] [loss+=0.822] [~loss+=0.63] [~loss=0.63] [~acc=78.56%]        28%]          \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: accuracy</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"294146f2\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Accuracy\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[0.18372006714344025,0.2884494960308075,0.37545257806777954,0.44579750299453735,0.49114879965782166,0.5476747751235962,0.5808466672897339,0.6150075793266296,0.64619380235672,0.6479228138923645,0.6758214831352234,0.6941400170326233,0.7090026140213013,0.7260935306549072,0.7332726120948792,0.7508320808410645,0.7648523449897766,0.7698003649711609,0.7771881222724915,0.7890937924385071,0.7921884059906006,0.7865213751792908,0.7879019975662231,0.7946189045906067,0.7856067419052124]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on Training\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[0.26802000403404236,0.3809199929237366,0.4485799968242645,0.5207799673080444,0.5472800135612488,0.5994600057601929,0.6157799959182739,0.6553599834442139,0.6785999536514282,0.6772399544715881,0.7376399636268616,0.7544999718666077,0.7712199687957764,0.7791199684143066,0.8071799874305725,0.8238799571990967,0.8251199722290039,0.836139976978302,0.8400200009346008,0.8558799624443054,0.8629199862480164,0.8643999695777893,0.861799955368042,0.8741999864578247,0.8674799799919128]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on 
Validation\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[0.27580001950263977,0.38120001554489136,0.44700002670288086,0.5164000391960144,0.5414000153541565,0.5903000235557556,0.6097000241279602,0.6431000232696533,0.6607000231742859,0.6587000489234924,0.7148000597953796,0.7338000535964966,0.738800048828125,0.7465000152587891,0.766800045967102,0.7808000445365906,0.7807000279426575,0.787600040435791,0.7866000533103943,0.7968000173568726,0.8067000508308411,0.8017000555992126,0.8059000372886658,0.8133000135421753,0.8090000152587891]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"accuracy\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('294146f2', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: loss</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"bcc9bce5\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Batch Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[1.9373050928115845,1.740271806716919,1.608765721321106,1.4295493364334106,1.4091596603393555,1.3459317684173584,1.276282787322998,0.9538552165031433,1.1211469173431396,1.0575112104415894,0.8651746511459351,1.0079002380371094,0.945870041847229,0.7179523706436157,1.1049443483352661,0.9717222452163696,0.3986932635307312,0.5114235281944275,0.7100256681442261,0.4445623457431793,0.7852447032928467,0.6405561566352844,0.6460273861885071,0.4668564796447754,0.8220787048339844]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[2.147197723388672,1.8619349002838135,1.6680160760879517,1.5029330253601074,1.3859893083572388,1.2567625045776367,1.1691949367523193,1.0817595720291138,1.0030250549316406,0.9882437586784363,0.907646894454956,0.8728981018066406,0.8384736180305481,0.7889049053192139,0.7585740685462952,0.7252750396728516,0.690473735332489,0.6712998747825623,0.6502280235290527,0.6272006630897522,0.6160404682159424,0.6397783756256104,0.626043975353241,0.6057541966438293,0.6298123598098755]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average 
Loss\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[2.147197723388672,1.8619349002838135,1.6680160760879517,1.5029330253601074,1.3859893083572388,1.2567625045776367,1.1691949367523193,1.0817595720291138,1.0030250549316406,0.9882437586784363,0.907646894454956,0.8728981018066406,0.8384736180305481,0.7889049053192139,0.7585740685462952,0.7252750396728516,0.690473735332489,0.6712998747825623,0.6502280235290527,0.6272006630897522,0.6160404682159424,0.6397783756256104,0.626043975353241,0.6057541966438293,0.6298123598098755]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on Training\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[1.9479873180389404,1.6887431144714355,1.4762221574783325,1.3028804063796997,1.2308348417282104,1.0973913669586182,1.047282099723816,0.9451701045036316,0.9107500314712524,0.910388708114624,0.73036128282547,0.6891148090362549,0.6405664086341858,0.6169004440307617,0.5417259335517883,0.5046582818031311,0.4987126290798187,0.45865583419799805,0.45154401659965515,0.4079904854297638,0.39691370725631714,0.39174652099609375,0.3909326493740082,0.3597249388694763,0.37641701102256775]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on 
Training\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[1.9479873180389404,1.6887431144714355,1.4762221574783325,1.3028804063796997,1.2308348417282104,1.0973913669586182,1.047282099723816,0.9451701045036316,0.9107500314712524,0.910388708114624,0.73036128282547,0.6891148090362549,0.6405664086341858,0.6169004440307617,0.5417259335517883,0.5046582818031311,0.4987126290798187,0.45865583419799805,0.45154401659965515,0.4079904854297638,0.39691370725631714,0.39174652099609375,0.3909326493740082,0.3597249388694763,0.37641701102256775]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on Validation\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[1.9429305791854858,1.6821945905685425,1.4791098833084106,1.310433030128479,1.2543238401412964,1.1241520643234253,1.0784733295440674,0.9899420142173767,0.9700405597686768,0.9807898998260498,0.8136986494064331,0.7802551984786987,0.7508494257926941,0.741960883140564,0.6681351661682129,0.6396552324295044,0.6495312452316284,0.6357854008674622,0.6252809762954712,0.5922130942344666,0.5761780142784119,0.5982009172439575,0.6054201722145081,0.5713459849357605,0.5762442350387573]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on 
Validation\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,11861,14434,17522,21228,25675,31011,37414,45098,54319,65384,78662,80000],\"y\":[1.9429305791854858,1.6821945905685425,1.4791098833084106,1.310433030128479,1.2543238401412964,1.1241520643234253,1.0784733295440674,0.9899420142173767,0.9700405597686768,0.9807898998260498,0.8136986494064331,0.7802551984786987,0.7508494257926941,0.741960883140564,0.6681351661682129,0.6396552324295044,0.6495312452316284,0.6357854008674622,0.6252809762954712,0.5922130942344666,0.5761780142784119,0.5982009172439575,0.6054201722145081,0.5713459849357605,0.5762442350387573]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"loss\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('bcc9bce5', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t[Step 80000] median train step: 916 microseconds\n",
      "\n",
      "Results on Validation:\n",
      "\tMean Loss+Regularization (#loss+): 0.576\n",
      "\tMean Loss (#loss): 0.576\n",
      "\tMean Accuracy (#acc): 80.90%\n",
      "Results on Training:\n",
      "\tMean Loss+Regularization (#loss+): 0.376\n",
      "\tMean Loss (#loss): 0.376\n",
      "\tMean Accuracy (#acc): 86.75%\n"
     ]
    }
   ],
   "source": [
    "%% --checkpoint=base_cnn_model --set=\"model=cnn;learning_rate=1e-3;l2_regularization=0;l1_regularization=0;train_steps=80000\"\n",
    "trainModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "924fd269-625f-4f16-852c-9796f6b34630",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "Inference (serving the model) uses the same Go code that was used to train the model.\n",
    "That is, currently the way to save a model is to export the Go function that creates the model, along with the checkpoint holding the learned weights.\n",
    "\n",
    "We created a small library, `cifar/classifier`, that takes an image as input, converts it to a tensor, and calls the trained model.\n"
   ]
  },
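  {
   "cell_type": "markdown",
   "id": "c4a81f3e-2b6d-4f9a-8e1c-5d7b0a3f9e62",
   "metadata": {},
   "source": [
    "The internals of `cifar/classifier` are not shown in this notebook. As a rough, hypothetical sketch of the \"convert it to a tensor\" step, the standard-library-only Go code below turns an `image.Image` into a `[height][width][3]` slice of RGB values scaled to `[0, 1]`, the same `(Height, Width, Depth)` layout as the model input built earlier with `cifar.Height`, `cifar.Width` and `cifar.Depth`. The helper name `imageToTensor` and the `[0, 1]` scaling are assumptions, not necessarily what the library does.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import (\n",
    "\t\"fmt\"\n",
    "\t\"image\"\n",
    "\t\"image/color\"\n",
    ")\n",
    "\n",
    "// imageToTensor is a hypothetical helper (not part of cifar/classifier): it\n",
    "// converts img to a [height][width][3] slice of RGB values scaled to [0, 1].\n",
    "// Alpha is dropped, which is harmless for fully opaque images.\n",
    "func imageToTensor(img image.Image) [][][3]float32 {\n",
    "\tbounds := img.Bounds()\n",
    "\ttensor := make([][][3]float32, bounds.Dy())\n",
    "\tfor y := range tensor {\n",
    "\t\ttensor[y] = make([][3]float32, bounds.Dx())\n",
    "\t\tfor x := range tensor[y] {\n",
    "\t\t\t// RGBA returns alpha-premultiplied 16-bit components in [0, 0xffff].\n",
    "\t\t\tr, g, b, _ := img.At(bounds.Min.X+x, bounds.Min.Y+y).RGBA()\n",
    "\t\t\ttensor[y][x] = [3]float32{\n",
    "\t\t\t\tfloat32(r) / 0xffff,\n",
    "\t\t\t\tfloat32(g) / 0xffff,\n",
    "\t\t\t\tfloat32(b) / 0xffff,\n",
    "\t\t\t}\n",
    "\t\t}\n",
    "\t}\n",
    "\treturn tensor\n",
    "}\n",
    "\n",
    "func main() {\n",
    "\t// A transparent 32x32 image with one opaque red pixel at the top-left corner.\n",
    "\timg := image.NewRGBA(image.Rect(0, 0, 32, 32))\n",
    "\timg.Set(0, 0, color.RGBA{R: 255, A: 255})\n",
    "\ttensor := imageToTensor(img)\n",
    "\tfmt.Println(len(tensor), len(tensor[0]), tensor[0][0]) // 32 32 [1 0 0]\n",
    "}\n",
    "```\n",
    "\n",
    "To feed the model, such values would still need a leading batch dimension (the model input above has shape `[batch_size, Height, Width, Depth]`) and would have to be copied into a tensor."
   ]
  },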
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f6db808a-dde2-4676-b1f4-f4b651c405f4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p>Image: (32 x 32)</p>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI5UlEQVR4nCxWSZMc1dV988upMmvo6uqWWvrUgFDwGWEMAZhwEAZ7YXvrrX+CN1765+GVg43DCwIhbAdCotVTVeX88o33OarxWWfcOPnuGS77458+Q4xWy0U1r7wF75CZIEaOEUUIDcNQ1zWlNIBnjAYIR0dHi9VKTZr5iAMYBEJKjpCx02azmi/mVrtJaWttjLGc5WyVPAwYUc1FWFo3OuMRIt4CISgEO01T13Xeey54VVVlXiUsGXZtbxSO2Acvi9wo58dJqR7jQCjBETvnxnFUatrvt+ze8SOPgAkZMFijjI7WGIxpWeYyyX3wp6enXdcpbTCi3gTlh7Efsipbnhw5FLU1hUwV4DxPEMa73Y5i2rUDAHDO0zRjxTw/TMWx67px1M55732aJsYq65AQAgAwxk3TqVFvqoVHlAG8tb7nGL1otlab0werfC26vhGCpFnS1v00TTHGIs+r+Zw5ZExQ+6692e66znjviiIDIEoF57yU0vsAENNE1tt9GHTGkv9/6+2TxfHXL//VDQ1F5NWLl6v1ilJKCBFCIjTEiBjn5XzBOCdmckM/9Z1q26FpGoyJ9zBNpm0H52A+X3KeVOX8yfnDk9W8HdX8wfkHX/z+P5fbq6axMWDOt7vm6nKbJrM0LTHiQqSE8ojZvu2ud3vGuXDWq1FBgKIoZrNZCMFab61zzl9dXRNCnPOJkKcPz4NY/fKLP3zy+e+e/fMfTT/gFCOHY8RdNxjtACBCZIzneaGN2W13CCHW32FUIwAImXjvrbV3+0kRQtbasiwRQsoBYunx2fGgyYvXzRSinkxCeKCwXK5RjJeXN3mRZJls+65rO8EYBrRarch2ux3GASEUY0QIOXcg4r03+gBKKQAYY6xz2tim7V+8uLq57SxCRrv2tnHa+RDrup2UybNZ2w510xymeZCUIetYjHD3JtYYmxclQhhhhCKCGKdpAgCtdZ7nszIbVJhUNKO22iXFzFrnxyEC3e46PaknTx5vNqfW2n1XV/OKuIC073c7+vS9pz4AIgRjioE7G4zWd/SnEBzGiDGWSJ6nSSJTpT0iPMuzJCFGjRSR4EPbDYQy66wQsppXMYLTJjpPIDKC6Re/+e3q6LgoyxiIVbFv+6Ht1DCg6DECKXmepcH7u5cjL1/92PZ7jIMU+e3tHkFs673RBlHmnEsSWVYlAdTtmuAsZ0xKwYw2Zhh7ddDo0I6jGkMwjBFKCKOcYu7twZMR0/2ujtqCMeryqmWcU7IfBmsdwXgaRylZs9+1y/J0vTbjkCVilsk8TVj0h8QZ+8EZbf0Qo8UkMk6dA8Z5kc8PciIkUjpNbc6kHVWom5rhPE8XVdnvbiilpeSr1YJitLu+miUcYVBqEBQQOPr+06eT1sZo66zzh6AIHsVI8qzcbO4vl2uCGaUiIIIxAeNU01qtgEQEQRz+lKVSSEE5p3kmrZmGoevapqn3nGHOKBv7wXsPHhgmgifWxBABxwRFMQ6aURXv9Cvy9PTk/qWDtqmbtvENkpzYg9igH3pGQAoavJvUMIxBCi45JQQRghiR0gfoxskHiIitlqdt21tjtzdbNYyJkPTgZEcJdoTM8kwfLfuhmYYGryrOBCMoOQzzlCEpudYMvMMEECYBRxM8sQhdbbej0s5BkhTL5TocXKa8NUZNzXan+sFpjazptrcMoaPFIhUSeevN5K1eVLN79zZpLiY9dn0XABjjANH64AFjJpkZVfSO4sAoWVV5V98EZ4TgEMIwdgHc5vgYIeTDwRllWXIuAnhKKcYYIjjns0LM0CzLMn8HSuSkpjxPpJilScmCtlWaQsJWq2VRyLa+KQoJmEUUrFHD0GIMaZomqdjttk27n81m3tsqz7xzwzBmARWzo7IsASBJkhijFKm1Xggxn8+rcsaIhypL15t7y0X546sLisysSPvpwCWAizF0
XTOOfV6kITgAb8xkjd4s5oSQpm4Q5ofgskoIwRhPEqmUgkOBCwBr3UTPVieUACPBWbXf3gBCiPB+0sPQoxhmZRERaDMJwU5PT4TkxugQfJ5lj99+ks/Kru+l5HW7k1IeAt7am9vX1zcX1ukAVqmeea9PVuuEQd/U2mqgiXa2GYaA/HozJ4RhErWZ8jJbHi+MMUmRQADwcFnfzBdL0bBtty3nM0cRZoCJK2a5dwaCF+yQpPTs3oYSoNHXTV0rjXliPLSjWi7K9apKk1SNkxrV6f2TYpYTSmMEhOKuri9eXxCKN5u1sdrYgy4IBs5JdAGFKIVMZJImKauOl5PXHBFeljISQqgzapYmp+t1Kgmj4ofvX2VptqzmyEPCeVrN67qOzj16cHa0WHCC33njzZffv8wR3SyP8iyJER0aUErOOaGUHW+OIvhU8izJqm6ajDXu8vH9B5vjlRCEUYmBfPfdv9948H/r9YoQIqV89uxZSunHn3xSzmbGmkzmD49Ov33+za8+/DjNJMKE/CRiAIQR++C9pygiQinngrP0xQ8/YEQ+//VnAD7EkMjs/Xd/8eWXf9vfbD96/+eUsn29d2r69KMP33v3Z9aarusikOLRo+fPv7m4uDg/f5im+U91K4TgnNO//uXP82qxXh1V1XK9Ptnebpfz6s3zc++sPHwiBePHR+tn33w9tF2WpV/9/SvB2aeffoQjaD3dbRIhQrWZur6hjNX7elJ6HIbt7c4YyyLCnDGEMYlYj8N6udA6ubi4MEYLISnVAA0h5PHbbzVte7PbdePw7vk7r68ujdGMHXomIpSm+f2zY4AgBItAnAdt7Diqvh/Zt8+/i3fAmNT7drFYtm3b9/2dJw8NhTEehgFhRBm9eH3x4OEDLuT1zbUxuqrmVVUyxjFGhHBKJaVM8JySO/cZjzFhXdcZY0IIMaLXF9fX1zdFUQghQggBwkEwMVprXfBZli2XyyRJtJ44EwRTSph3gWCGCfYu6EkbY+dzRDBTSmGMhBSMMfa/jUd8dnYWQqCUpmn604Hkvc/zXEoJKBZFcTiQlOIHylQIPJ9XhJAYo1JKa51lWTnLrLHWqjzPGWPe+/8GAAD//7swHNFia9prAAAAAElFTkSuQmCC"
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p>Class: <b>horse (7)</b></p>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import (\n",
    "    \"encoding/base64\"\n",
    "    \"image/png\"\n",
    "    \n",
    "    \"github.com/gomlx/gomlx/examples/cifar/classifier\"\n",
    "    // We must also import the backend (engine) that will execute the model.\n",
    "    // Currently only XLA is supported.\n",
    "    _ \"github.com/gomlx/gomlx/backends/default\"\n",
    ")\n",
    "\n",
    "%%\n",
    "// Decode the base64-encoded PNG image and display it.\n",
    "imgBase64 := bytes.NewBufferString(\n",
    "    \"iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI5UlEQVR4nCxWSZMc1dV988upMmvo6uqWWvrUgFDwGWEMAZhwEAZ7YXvrrX+CN1765+GVg43DCwIhbAdCotVTVeX88o33OarxWWfcOPnuGS77458+Q4xWy0U1r7wF75CZIEaOEUUIDcNQ1zWlNIBnjAYIR0dHi9VKTZr5iAMYBEJKjpCx02azmi/mVrtJaWttjLGc5WyVPAwYUc1FWFo3OuMRIt4CISgEO01T13Xeey54VVVlXiUsGXZtbxSO2Acvi9wo58dJqR7jQCjBETvnxnFUatrvt+ze8SOPgAkZMFijjI7WGIxpWeYyyX3wp6enXdcpbTCi3gTlh7Efsipbnhw5FLU1hUwV4DxPEMa73Y5i2rUDAHDO0zRjxTw/TMWx67px1M55732aJsYq65AQAgAwxk3TqVFvqoVHlAG8tb7nGL1otlab0werfC26vhGCpFnS1v00TTHGIs+r+Zw5ZExQ+6692e66znjviiIDIEoF57yU0vsAENNE1tt9GHTGkv9/6+2TxfHXL//VDQ1F5NWLl6v1ilJKCBFCIjTEiBjn5XzBOCdmckM/9Z1q26FpGoyJ9zBNpm0H52A+X3KeVOX8yfnDk9W8HdX8wfkHX/z+P5fbq6axMWDOt7vm6nKbJrM0LTHiQqSE8ojZvu2ud3vGuXDWq1FBgKIoZrNZCMFab61zzl9dXRNCnPOJkKcPz4NY/fKLP3zy+e+e/fMfTT/gFCOHY8RdNxjtACBCZIzneaGN2W13CCHW32FUIwAImXjvrbV3+0kRQtbasiwRQsoBYunx2fGgyYvXzRSinkxCeKCwXK5RjJeXN3mRZJls+65rO8EYBrRarch2ux3GASEUY0QIOXcg4r03+gBKKQAYY6xz2tim7V+8uLq57SxCRrv2tnHa+RDrup2UybNZ2w510xymeZCUIetYjHD3JtYYmxclQhhhhCKCGKdpAgCtdZ7nszIbVJhUNKO22iXFzFrnxyEC3e46PaknTx5vNqfW2n1XV/OKuIC073c7+vS9pz4AIgRjioE7G4zWd/SnEBzGiDGWSJ6nSSJTpT0iPMuzJCFGjRSR4EPbDYQy66wQsppXMYLTJjpPIDKC6Re/+e3q6LgoyxiIVbFv+6Ht1DCg6DECKXmepcH7u5cjL1/92PZ7jIMU+e3tHkFs673RBlHmnEsSWVYlAdTtmuAsZ0xKwYw2Zhh7ddDo0I6jGkMwjBFKCKOcYu7twZMR0/2ujtqCMeryqmWcU7IfBmsdwXgaRylZs9+1y/J0vTbjkCVilsk8TVj0h8QZ+8EZbf0Qo8UkMk6dA8Z5kc8PciIkUjpNbc6kHVWom5rhPE8XVdnvbiilpeSr1YJitLu+miUcYVBqEBQQOPr+06eT1sZo66zzh6AIHsVI8qzcbO4vl2uCGaUiIIIxAeNU01qtgEQEQRz+lKVSSEE5p3kmrZmGoevapqn3nGHOKBv7wXsPHhgmgifWxBABxwRFMQ6aURXv9Cvy9PTk/qWDtqmbtvENkpzYg9igH3pGQAoavJvUMIxBCi45JQQRghiR0gfoxskHiIitlqdt21tjtzdbNYyJkPTgZEcJdoTM8kwfLfuhmYYGryrOBCMoOQzzlCEpudYMvMMEECYBRxM8sQhdbbej0s5BkhTL5TocXKa8NUZNzXan+sFpjazptrcMoaPFIhUSeevN5K1eVLN79zZpLiY9dn0XABjjANH64AFjJpkZVfSO4sAoWVV5V98EZ4TgEMIwdgHc5vgYIeTDwRllWXIuAnhKKcYYIjjns0LM0CzLMn8HSuSkpjxPpJilScmCtlWaQsJWq2VRyLa+KQoJmEUUrFHD0GIMaZomqdjttk27n81m3tsqz7xzwzBmARWzo7IsASBJkhijFKm1Xggxn8+rcsaIhypL15t7y0X546sLisysSPvpwCWAizF0XTOOfV6kI
TgAb8xkjd4s5oSQpm4Q5ofgskoIwRhPEqmUgkOBCwBr3UTPVieUACPBWbXf3gBCiPB+0sPQoxhmZRERaDMJwU5PT4TkxugQfJ5lj99+ks/Kru+l5HW7k1IeAt7am9vX1zcX1ukAVqmeea9PVuuEQd/U2mqgiXa2GYaA/HozJ4RhErWZ8jJbHi+MMUmRQADwcFnfzBdL0bBtty3nM0cRZoCJK2a5dwaCF+yQpPTs3oYSoNHXTV0rjXliPLSjWi7K9apKk1SNkxrV6f2TYpYTSmMEhOKuri9eXxCKN5u1sdrYgy4IBs5JdAGFKIVMZJImKauOl5PXHBFeljISQqgzapYmp+t1Kgmj4ofvX2VptqzmyEPCeVrN67qOzj16cHa0WHCC33njzZffv8wR3SyP8iyJER0aUErOOaGUHW+OIvhU8izJqm6ajDXu8vH9B5vjlRCEUYmBfPfdv9948H/r9YoQIqV89uxZSunHn3xSzmbGmkzmD49Ov33+za8+/DjNJMKE/CRiAIQR++C9pygiQinngrP0xQ8/YEQ+//VnAD7EkMjs/Xd/8eWXf9vfbD96/+eUsn29d2r69KMP33v3Z9aarusikOLRo+fPv7m4uDg/f5im+U91K4TgnNO//uXP82qxXh1V1XK9Ptnebpfz6s3zc++sPHwiBePHR+tn33w9tF2WpV/9/SvB2aeffoQjaD3dbRIhQrWZur6hjNX7elJ6HIbt7c4YyyLCnDGEMYlYj8N6udA6ubi4MEYLISnVAA0h5PHbbzVte7PbdePw7vk7r68ujdGMHXomIpSm+f2zY4AgBItAnAdt7Diqvh/Zt8+/i3fAmNT7drFYtm3b9/2dJw8NhTEehgFhRBm9eH3x4OEDLuT1zbUxuqrmVVUyxjFGhHBKJaVM8JySO/cZjzFhXdcZY0IIMaLXF9fX1zdFUQghQggBwkEwMVprXfBZli2XyyRJtJ44EwRTSph3gWCGCfYu6EkbY+dzRDBTSmGMhBSMMfa/jUd8dnYWQqCUpmn604Hkvc/zXEoJKBZFcTiQlOIHylQIPJ9XhJAYo1JKa51lWTnLrLHWqjzPGWPe+/8GAAD//7swHNFia9prAAAAAElFTkSuQmCC\")\n",
    "imgPNG := must.M1(io.ReadAll(base64.NewDecoder(base64.StdEncoding, imgBase64)))\n",
    "img := must.M1(png.Decode(bytes.NewBuffer(imgPNG)))\n",
    "size := img.Bounds()\n",
    "gonbui.DisplayHTML(fmt.Sprintf(\"<p>Image: (%d x %d)</p>\", size.Dx(), size.Dy()))\n",
    "gonbui.DisplayPNG(imgPNG)\n",
    "\n",
    "// Classify:\n",
    "c10Classifier := must.M1(classifier.New(\"~/work/cifar/base_fnn_model\"))\n",
    "classID := must.M1(c10Classifier.Classify(img))\n",
    "className := cifar.C10Labels[classID]\n",
    "gonbui.DisplayHTML(fmt.Sprintf(\"<p>Class: <b>%s (%d)</b></p>\", className, classID))"
   ]
  },
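  {
   "cell_type": "markdown",
   "id": "5e2b7c1d-8a4f-4c3e-9b6d-2f1a0c9e7d34",
   "metadata": {},
   "source": [
    "The cell above turns an inline base64 string into PNG bytes and then into an `image.Image` before classifying it. Below is a minimal sketch of that round trip using only the Go standard library (no GoMLX or `gonbui`). It encodes an image as a PNG data URI, strips the `data:image/png;base64,` preamble and decodes the remainder back into an image. The helpers `toDataURI` and `fromBase64` and the synthetic 32x32 test image are hypothetical and exist only for illustration. This is not how the `cifar` library itself handles images.\n",
    "\n",
    "```go\n",
    "package main\n",
    "\n",
    "import (\n",
    "\t\"bytes\"\n",
    "\t\"encoding/base64\"\n",
    "\t\"fmt\"\n",
    "\t\"image\"\n",
    "\t\"image/color\"\n",
    "\t\"image/png\"\n",
    "\t\"strings\"\n",
    ")\n",
    "\n",
    "// dataURIPrefix is the preamble of a base64 PNG data URI, as used in an <img> src attribute.\n",
    "const dataURIPrefix = \"data:image/png;base64,\"\n",
    "\n",
    "// toDataURI encodes img as PNG and returns it as a data URI (hypothetical helper).\n",
    "func toDataURI(img image.Image) (string, error) {\n",
    "\tvar buf bytes.Buffer\n",
    "\tif err := png.Encode(&buf, img); err != nil {\n",
    "\t\treturn \"\", err\n",
    "\t}\n",
    "\treturn dataURIPrefix + base64.StdEncoding.EncodeToString(buf.Bytes()), nil\n",
    "}\n",
    "\n",
    "// fromBase64 decodes a preamble-free base64 PNG string back into an image (hypothetical helper).\n",
    "func fromBase64(s string) (image.Image, error) {\n",
    "\tpngBytes, err := base64.StdEncoding.DecodeString(s)\n",
    "\tif err != nil {\n",
    "\t\treturn nil, err\n",
    "\t}\n",
    "\treturn png.Decode(bytes.NewReader(pngBytes))\n",
    "}\n",
    "\n",
    "func main() {\n",
    "\t// Synthetic 32x32 gradient standing in for a Cifar sample.\n",
    "\tsrc := image.NewRGBA(image.Rect(0, 0, 32, 32))\n",
    "\tfor y := 0; y < 32; y++ {\n",
    "\t\tfor x := 0; x < 32; x++ {\n",
    "\t\t\tsrc.Set(x, y, color.RGBA{R: uint8(8 * x), G: uint8(8 * y), B: 128, A: 255})\n",
    "\t\t}\n",
    "\t}\n",
    "\turi, err := toDataURI(src)\n",
    "\tif err != nil {\n",
    "\t\tpanic(err)\n",
    "\t}\n",
    "\t// Stripping the preamble leaves the bare base64 string.\n",
    "\timg, err := fromBase64(strings.TrimPrefix(uri, dataURIPrefix))\n",
    "\tif err != nil {\n",
    "\t\tpanic(err)\n",
    "\t}\n",
    "\tb := img.Bounds()\n",
    "\tfmt.Printf(\"prefix length: %d, decoded image: %d x %d\\n\", len(dataURIPrefix), b.Dx(), b.Dy())\n",
    "}\n",
    "```\n",
    "\n",
    "Running it should print `prefix length: 22, decoded image: 32 x 32`. The 22-byte preamble is why the commented-out `imgSrc[22:]` in the last cell yields the bare base64 string."
   ]
  },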
  {
   "cell_type": "markdown",
   "id": "df79459f-9aca-4f52-b244-0b6109a69b9d",
   "metadata": {},
   "source": [
    "### Generate a base64 PNG from a random Cifar-10 training image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "898eaefe-5934-473e-a47c-6f01bfb81df0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img width=\"64\" height=\"64\" src=\"\"/>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "// Sample a random Cifar-10 training image, encode it as a PNG data URI and display it at 2x size.\n",
    "%%\n",
    "backend := backends.MustNew()\n",
    "ds := cifar.NewDataset(backend, \"Samples Cifar-10\", *flagDataDir, cifar.C10, DType, cifar.Train).Shuffle()\n",
    "_, inputs, _ := must.M3(ds.Yield())\n",
    "imgTensor := inputs[0]\n",
    "img := images.ToImage().Single(imgTensor)\n",
    "imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))\n",
    "// imgBase64 := imgSrc[22:]  // Strip the 22-byte \"data:image/png;base64,\" preamble used in an <img> src attribute.\n",
    "// fmt.Printf(\"%s\\n\\n\", imgBase64)\n",
    "size := imgTensor.Shape().Dimensions[0]\n",
    "gonbui.DisplayHTML(fmt.Sprintf(`<img width=\"%d\" height=\"%d\" src=\"%s\"/>`, size*2, size*2, imgSrc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Go (gonb)",
   "language": "go",
   "name": "gonb"
  },
  "language_info": {
   "codemirror_mode": "",
   "file_extension": ".go",
   "mimetype": "text/x-go",
   "name": "go",
   "nbconvert_exporter": "",
   "pygments_lexer": "",
   "version": "go1.24.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
